import sys
import os
import mxnet as mx
import symbol_utils
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from config import config


def Act(data, act_type, name):
    #ignore param act_type, set it in this function
    if act_type == 'prelu':
        body = mx.sym.LeakyReLU(data=data, act_type='prelu', name=name)
    else:
        body = mx.sym.Activation(data=data, act_type=act_type, name=name)
    return body


def Conv(data,
         num_filter=1,
         kernel=(1, 1),
         stride=(1, 1),
         pad=(0, 0),
         num_group=1,
         name=None,
         suffix=''):
    conv = mx.sym.Convolution(data=data,
                              num_filter=num_filter,
                              kernel=kernel,
                              num_group=num_group,
                              stride=stride,
                              pad=pad,
                              no_bias=True,
                              name='%s%s_conv2d' % (name, suffix))
    bn = mx.sym.BatchNorm(data=conv,
                          name='%s%s_batchnorm' % (name, suffix),
                          fix_gamma=False,
                          momentum=config.bn_mom)
    act = Act(data=bn,
              act_type=config.net_act,
              name='%s%s_relu' % (name, suffix))
    return act


def Linear(data,
           num_filter=1,
           kernel=(1, 1),
           stride=(1, 1),
           pad=(0, 0),
           num_group=1,
           name=None,
           suffix=''):
    conv = mx.sym.Convolution(data=data,
                              num_filter=num_filter,
                              kernel=kernel,
                              num_group=num_group,
                              stride=stride,
                              pad=pad,
                              no_bias=True,
                              name='%s%s_conv2d' % (name, suffix))
    bn = mx.sym.BatchNorm(data=conv,
                          name='%s%s_batchnorm' % (name, suffix),
                          fix_gamma=False,
                          momentum=config.bn_mom)
    return bn


def ConvOnly(data,
             num_filter=1,
             kernel=(1, 1),
             stride=(1, 1),
             pad=(0, 0),
             num_group=1,
             name=None,
             suffix=''):
    conv = mx.sym.Convolution(data=data,
                              num_filter=num_filter,
                              kernel=kernel,
                              num_group=num_group,
                              stride=stride,
                              pad=pad,
                              no_bias=True,
                              name='%s%s_conv2d' % (name, suffix))
    return conv


def DResidual(data,
              num_out=1,
              kernel=(3, 3),
              stride=(2, 2),
              pad=(1, 1),
              num_group=1,
              name=None,
              suffix=''):
    conv = Conv(data=data,
                num_filter=num_group,
                kernel=(1, 1),
                pad=(0, 0),
                stride=(1, 1),
                name='%s%s_conv_sep' % (name, suffix))
    conv_dw = Conv(data=conv,
                   num_filter=num_group,
                   num_group=num_group,
                   kernel=kernel,
                   pad=pad,
                   stride=stride,
                   name='%s%s_conv_dw' % (name, suffix))
    proj = Linear(data=conv_dw,
                  num_filter=num_out,
                  kernel=(1, 1),
                  pad=(0, 0),
                  stride=(1, 1),
                  name='%s%s_conv_proj' % (name, suffix))
    return proj


def Residual(data,
             num_block=1,
             num_out=1,
             kernel=(3, 3),
             stride=(1, 1),
             pad=(1, 1),
             num_group=1,
             name=None,
             suffix=''):
    identity = data
    for i in range(num_block):
        shortcut = identity
        conv = DResidual(data=identity,
                         num_out=num_out,
                         kernel=kernel,
                         stride=stride,
                         pad=pad,
                         num_group=num_group,
                         name='%s%s_block' % (name, suffix),
                         suffix='%d' % i)
        identity = conv + shortcut
    return identity


def get_symbol():
    num_classes = config.emb_size
    print('in_network', config)
    fc_type = config.net_output
    data = mx.symbol.Variable(name="data")
    data = data - 127.5
    data = data * 0.0078125
    blocks = config.net_blocks
    conv_1 = Conv(data,
                  num_filter=64,
                  kernel=(3, 3),
                  pad=(1, 1),
                  stride=(2, 2),
                  name="conv_1")
    if blocks[0] == 1:
        conv_2_dw = Conv(conv_1,
                         num_group=64,
                         num_filter=64,
                         kernel=(3, 3),
                         pad=(1, 1),
                         stride=(1, 1),
                         name="conv_2_dw")
    else:
        conv_2_dw = Residual(conv_1,
                             num_block=blocks[0],
                             num_out=64,
                             kernel=(3, 3),
                             stride=(1, 1),
                             pad=(1, 1),
                             num_group=64,
                             name="res_2")
    conv_23 = DResidual(conv_2_dw,
                        num_out=64,
                        kernel=(3, 3),
                        stride=(2, 2),
                        pad=(1, 1),
                        num_group=128,
                        name="dconv_23")
    conv_3 = Residual(conv_23,
                      num_block=blocks[1],
                      num_out=64,
                      kernel=(3, 3),
                      stride=(1, 1),
                      pad=(1, 1),
                      num_group=128,
                      name="res_3")
    conv_34 = DResidual(conv_3,
                        num_out=128,
                        kernel=(3, 3),
                        stride=(2, 2),
                        pad=(1, 1),
                        num_group=256,
                        name="dconv_34")
    conv_4 = Residual(conv_34,
                      num_block=blocks[2],
                      num_out=128,
                      kernel=(3, 3),
                      stride=(1, 1),
                      pad=(1, 1),
                      num_group=256,
                      name="res_4")
    conv_45 = DResidual(conv_4,
                        num_out=128,
                        kernel=(3, 3),
                        stride=(2, 2),
                        pad=(1, 1),
                        num_group=512,
                        name="dconv_45")
    conv_5 = Residual(conv_45,
                      num_block=blocks[3],
                      num_out=128,
                      kernel=(3, 3),
                      stride=(1, 1),
                      pad=(1, 1),
                      num_group=256,
                      name="res_5")
    conv_6_sep = Conv(conv_5,
                      num_filter=512,
                      kernel=(1, 1),
                      pad=(0, 0),
                      stride=(1, 1),
                      name="conv_6sep")

    fc1 = symbol_utils.get_fc1(conv_6_sep, num_classes, fc_type)
    return fc1
